import fnmatch
from typing import Union

import torch

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter

# isort: off
from .layers import (
    FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear, Fp8RowwiseAttention,
    Fp8RowwiseGatedMLP, Fp8RowwiseLayerNorm, Fp8RowwiseMLP, Fp8RowwiseRmsNorm,
    Int8SmoothQuantLinear, Int8SmoothQuantRowLinear, QServeAttention,
    QServeGatedMLP, QServeMLP, QServeRmsNorm, SmoothQuantAttention,
    SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
    SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
    WeightOnlyGroupwiseQuantRowLinear, WeightOnlyQuantColumnLinear,
    WeightOnlyQuantEmbedding, WeightOnlyQuantRowLinear)
# isort: on
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode


def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
    """Replace modules of ``model`` with the quantized classes in ``quant_map``.

    Args:
        model: Root module; must provide ``named_modules_with_parent()``.
        quant_config: Supplies ``exclude_modules`` (fnmatch patterns matched
            case-sensitively against the full dotted module name; a built-in
            default list is used when it is ``None``), ``quant_mode`` and
            ``quant_algo``.
        quant_map: Mapping of source class -> replacement class. The first
            key (in dict order) that ``isinstance``-matches a module wins.
        preprocess_init_params: Optional hook ``f(init_params, name, module)``
            that may mutate the constructor kwargs in place before the
            replacement module is built.

    Returns:
        The root module (replaced if the root itself matched), with its
        ``quant_mode`` attribute set to ``quant_config.quant_mode``.
    """
    patterns = quant_config.exclude_modules
    if patterns is None:
        patterns = [
            '*lm_head',
            '*router',
            '*vocab_embedding',
            '*position_embedding',
            '*block_embedding',
            '*shared_expert_gate',
        ]

    for name, module, parent in model.named_modules_with_parent():
        excluded = any(
            fnmatch.fnmatchcase(name, pattern) for pattern in patterns)
        is_moe = isinstance(module, MixtureOfExperts)

        # MoE modules are quantized in their constructor, so they are always
        # rebuilt (as their own type), even when excluded or absent from
        # quant_map; an excluded MoE is rebuilt with QuantMode(0) below.
        if is_moe:
            target_cls = type(module)
        elif excluded:
            target_cls = None
        else:
            target_cls = next((quant_map[base]
                               for base in quant_map
                               if isinstance(module, base)), None)

        if not target_cls:
            continue

        params = get_init_params(module, target_cls)
        if is_moe:
            params["quant_mode"] = (QuantMode(0)
                                    if excluded else quant_config.quant_mode)
            # AWQ-style algorithms carry a pre_quant_scale; turn it on for MoE.
            if quant_config.quant_algo in [
                    QuantAlgo.W4A16_AWQ, QuantAlgo.NVFP4_AWQ,
                    QuantAlgo.W4A8_AWQ
            ]:
                params["pre_quant_scale"] = True
        elif "bias" in params:
            # The quantized constructors take a boolean "has bias" flag.
            params["bias"] = params["bias"] is not None

        # Sharded linears store their per-rank size; the constructor is given
        # size * tp_size (presumably the un-sharded size - confirm).
        if isinstance(module, ColumnLinear):
            params["out_features"] = module.out_features * module.tp_size
        elif isinstance(module, RowLinear):
            params["in_features"] = module.in_features * module.tp_size

        if preprocess_init_params is not None:
            preprocess_init_params(params, name, module)

        replacement = target_cls(**params)
        if parent is None:
            model = replacement
        else:
            setattr(parent, name.rsplit('.', 1)[-1], replacement)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
    """Apply per-channel weight-only quantization to linears and embeddings.

    Args:
        model: Module tree to quantize.
        quant_config: Must describe a weight-only quant mode.
        model_config: Fallback config used for ``mapping.tp_rank`` when
            ``model`` has no ``config`` attribute.

    Returns:
        The quantized model produced by ``quantize_layers``.
    """
    assert quant_config.quant_mode.is_weight_only()

    # Equivalent to try: model.config / except AttributeError: model_config
    model_cfg = getattr(model, 'config', model_config)

    def _weight_only_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            # transb is enabled only for the module whose leaf name is lm_head.
            leaf = name.rsplit('.', 1)[-1]
            init_params["transb"] = leaf == "lm_head"
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    return quantize_layers(
        model,
        quant_config,
        {
            ColumnLinear: WeightOnlyQuantColumnLinear,
            RowLinear: WeightOnlyQuantRowLinear,
            Embedding: WeightOnlyQuantEmbedding,
        },
        _weight_only_params,
    )


def weight_only_groupwise_quantize(model,
                                   quant_config: QuantConfig,
                                   model_config=None):
    """Apply group-wise weight-only quantization to linears and MoE modules.

    Args:
        model: Module tree to quantize.
        quant_config: Must describe a weight-only quant mode. Its
            ``group_size``, ``pre_quant_scale``, ``has_zero_point`` and
            ``quant_algo`` are forwarded to every replacement constructor.
        model_config: Fallback config used for ``mapping.tp_rank`` when
            ``model`` has no ``config`` attribute.

    Returns:
        The quantized model produced by ``quantize_layers``.
    """
    assert quant_config.quant_mode.is_weight_only()

    model_cfg = getattr(model, 'config', model_config)

    def _groupwise_params(init_params, name, module):
        algo = quant_config.quant_algo
        init_params.update({
            "group_size": quant_config.group_size,
            "pre_quant_scale": quant_config.pre_quant_scale,
            "zero": quant_config.has_zero_point,
            "use_w4a8_awq": algo == QuantAlgo.W4A8_AWQ,
            "use_int8_weight": algo == QuantAlgo.W8A16_GPTQ,
        })
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    return quantize_layers(
        model,
        quant_config,
        {
            ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
            RowLinear: WeightOnlyGroupwiseQuantRowLinear,
            MixtureOfExperts: MixtureOfExperts,
        },
        _groupwise_params,
    )


def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
    """SmoothQuant without plugins: swap column/row linears for INT8 SQ ones.

    Returns:
        The quantized model produced by ``quantize_layers``.
    """
    sq_linears = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }
    return quantize_layers(model, quant_config, sq_linears)


def smooth_quantize_plugin(model, quant_mode):
    """Plugin SmoothQuant: swap norms, MLPs and attention for SQ variants.

    Modules whose leaf name is ``ln_f`` or ``ln_embed`` are left untouched.

    Args:
        model: Root module; must provide ``named_modules_with_parent()``.
        quant_mode: QuantMode passed to every replacement and stored on the
            returned model.

    Returns:
        The root module (replaced if the root itself matched).
    """
    replacements = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    skipped = ('ln_f', 'ln_embed')

    for name, layer, parent in model.named_modules_with_parent():
        attr = name.rsplit('.', 1)[-1]
        if attr in skipped:
            continue

        target_cls = next((new for old, new in replacements.items()
                           if isinstance(layer, old)), None)
        if target_cls is None:
            continue

        params = get_init_params(layer, target_cls)
        params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            # The layer stores its per-rank head count; scale by tp_size.
            params["num_attention_heads"] = (layer.num_attention_heads *
                                             layer.tp_size)

        new_layer = target_cls(**params)
        if parent is None:
            model = new_layer
        else:
            setattr(parent, attr, new_layer)

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_config: QuantConfig):
    """Dispatch to the plugin or out-of-the-box SmoothQuant implementation.

    Algorithms listed in ``W8A8_SQ_PLUGIN_LIST`` use the plugin path; all
    others use the OOTB linear replacement.
    """
    assert quant_config.quant_mode.has_act_and_weight_quant()
    uses_plugin = quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST
    if not uses_plugin:
        return smooth_quantize_ootb(model, quant_config)
    return smooth_quantize_plugin(model, quant_config.quant_mode)


def fp8_quantize(model, quant_config: QuantConfig):
    """Apply FP8 QDQ quantization to linears and (re-created) MoE modules.

    Returns:
        The quantized model produced by ``quantize_layers``.
    """
    assert quant_config.quant_mode.has_fp8_qdq()
    fp8_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }
    return quantize_layers(model, quant_config, fp8_map)


def fp8_rowwise_quantize(model, quant_config: QuantConfig):
    """Swap norms, MLPs and attention modules for FP8 rowwise variants.

    ``*ln_f``, ``*ln_embed`` and ``*lm_head`` are always excluded, in addition
    to any fnmatch patterns in ``quant_config.exclude_modules``. When
    ``quant_config.use_meta_recipe`` is set, Meta's LLaMA 3.1 recipe is also
    applied. That recipe skips modules in the first and last local
    transformer layers on the first and last pipeline ranks. It also
    excludes ``*input_layernorm`` and ``*attention``.

    Args:
        model: Root module; must provide ``named_modules_with_parent()`` and,
            when the meta recipe is used, ``config.mapping`` and
            ``config.num_hidden_layers``.
        quant_config: FP8-rowwise quant config; ``clamp_val`` is forwarded to
            every replacement constructor.

    Returns:
        The root module (replaced if the root itself matched), with its
        ``quant_mode`` attribute set to ``quant_config.quant_mode``.
    """
    assert quant_config.quant_mode.has_fp8_rowwise()

    quant_cls_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        LayerNorm: Fp8RowwiseLayerNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
        Attention: Fp8RowwiseAttention,
    }

    # Always exclude these modules for FP8 rowwise. Accept any iterable of
    # patterns (list or tuple) from the config.
    exclude_modules = list({
        *(quant_config.exclude_modules or []), '*ln_f', '*ln_embed',
        '*lm_head'
    })

    def extract_layer_idx(name):
        # First purely numeric path component, e.g. 'layers.3.mlp' -> 3.
        return next((int(s) for s in name.split('.') if s.isdigit()), None)

    use_meta_recipe = quant_config.use_meta_recipe
    skip_first = skip_last = False
    last_local_idx = None
    # Meta's LLaMA 3.1 recipe:
    # (1) Skip quantization for the first and last Transformer layers
    # (2) Skip quantization for the Attention layers
    if use_meta_recipe:
        exclude_modules.extend(['*input_layernorm', '*attention'])
        # The pipeline mapping and this rank's layer range do not change
        # while iterating modules, so resolve them once instead of per module.
        mapping = model.config.mapping
        layers_range = mapping.pp_layers(model.config.num_hidden_layers)
        skip_first = mapping.is_first_pp_rank()
        skip_last = mapping.is_last_pp_rank()
        last_local_idx = len(layers_range) - 1

    for name, layer, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]

        if use_meta_recipe:
            local_layer_idx = extract_layer_idx(name)
            if skip_first and local_layer_idx == 0:
                continue
            if skip_last and local_layer_idx == last_local_idx:
                continue

        quant_cls = next((new for old, new in quant_cls_map.items()
                          if isinstance(layer, old)), None)
        if quant_cls is None:
            continue

        if any(
                fnmatch.fnmatchcase(name, pattern)
                for pattern in exclude_modules):
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(layer, Attention):
            # The layer stores its per-rank head count; scale by tp_size.
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params, clamp_val=quant_config.clamp_val)
        if parent is not None:
            setattr(parent, module_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


# TODO: These functions should be moved to ModelOpt.
def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
                                     s1_scales: torch.FloatTensor,
                                     s2_scales: torch.FloatTensor,
                                     s2_szeros: torch.FloatTensor,
                                     group_size: int) -> torch.CharTensor:
    """Two-stage QServe quantization: per-channel int8, then per-group uint4.

    Args:
        linear_weight: Weight of shape ``(out_features, in_features)``.
        s1_scales: Stage-1 per-output-channel scales (``out_features`` values).
        s2_scales: Stage-2 per-group scales (``out_features * in_features //
            group_size`` values).
        s2_szeros: Stage-2 per-group zero points, same element count as
            ``s2_scales``; added before dividing by the scale.
        group_size: Number of input columns per stage-2 group; must divide
            ``in_features``.

    Returns:
        int8 tensor of shape ``(out_features, in_features)`` with values in
        ``[0, 15]``.

    Raises:
        AssertionError: If a stage produces values outside its range.
    """
    rows = linear_weight.shape[0]
    cols = linear_weight.shape[1]
    num_groups = cols // group_size
    device = linear_weight.device

    # Stage 1: divide by the per-row scale and round to int8 range.
    # (A stricter +/-119 bound was once checked here and is disabled.)
    stage1 = linear_weight.div(s1_scales.reshape(rows, 1).to(device)).round()
    assert (stage1.min() >= -128 and stage1.max()
            <= 127), "Stage 1: Quantized weight out of range"

    # Stage 2: shift by the group zero point, divide by the group scale.
    group_shape = (rows, num_groups, 1)
    zeros = s2_szeros.reshape(group_shape).to(torch.float16).to(device)
    scales = s2_scales.reshape(group_shape).to(torch.float16).to(device)
    grouped = stage1.reshape(rows, num_groups, group_size)
    stage2 = grouped.add(zeros).div(scales).round()
    assert (stage2.min() >= 0 and stage2.max()
            <= 15), "Stage 2: Quantized weight out of range"

    return stage2.reshape(rows, cols).to(torch.int8)


def qserve_quantize_weight_per_channel(
        linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
        s1_szeros: torch.FloatTensor) -> torch.CharTensor:
    """Single-stage QServe quantization to uint4 with per-channel params.

    Args:
        linear_weight: Weight of shape ``(out_features, in_features)``.
        s1_scales: Per-output-channel scales (``out_features`` values).
        s1_szeros: Per-output-channel zero points (``out_features`` values),
            added before dividing by the scale.

    Returns:
        int8 tensor of shape ``(out_features, in_features)`` with values in
        ``[0, 15]``.

    Raises:
        AssertionError: If any quantized value falls outside ``[0, 15]``.
    """
    n_out = linear_weight.shape[0]
    n_in = linear_weight.shape[1]
    target = linear_weight.device
    column = (n_out, 1)

    scales = s1_scales.reshape(column).to(target)
    zeros = s1_szeros.reshape(column).to(target)

    quantized = linear_weight.add(zeros).div(scales).round()
    assert (quantized.min() >= 0
            and quantized.max() <= 15), "Quantized weight out of range"

    return quantized.reshape(n_out, n_in).to(torch.int8)


# Pack the quantized weights, scales and zeros and apply the reordering required by QServe kernels.
# Return: processed [qweight, s1_scales, s2_scales, s2_zeros]
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
                                  s1_scales: torch.FloatTensor,
                                  s2_scales: torch.FloatTensor,
                                  s2_szeros: torch.FloatTensor, group_size):
    """Pack 4-bit weights two-per-byte and reorder weights/scales/zeros.

    The layout produced here is the one expected by the QServe per-group
    kernels.

    Args:
        qweight: int8 tensor of shape ``(out_features, in_features)``;
            presumably 4-bit values in ``[0, 15]`` as produced by
            ``qserve_quantize_weight_per_group`` (TODO confirm at callers).
            Both dimensions must be multiples of 32 for the reshapes below.
        s1_scales: Per-output-channel scales (``out_features`` values).
        s2_scales: Per-group scales (``out_features * in_features //
            group_size`` values); cast to int8.
        s2_szeros: Per-group zero points, same element count as
            ``s2_scales``; cast to int8 and negated.
        group_size: Number of input columns per group.

    Returns:
        ``[packed_qweight, s1_scales, s2_scales, s2_zeros]`` where
        ``packed_qweight`` is int8 of shape ``(out_features, in_features // 2)``
        (tile-ordered), ``s1_scales`` is float16 of shape ``(out_features,)``,
        and ``s2_scales`` / ``s2_zeros`` are int8 of shape
        ``(in_features // group_size, out_features)``.
    """
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    s1_scales = s1_scales.reshape(out_features).to(torch.float16)
    s2_szeros = s2_szeros.reshape(out_features,
                                  in_features // group_size).to(torch.int8)
    s2_scales = s2_scales.reshape(out_features,
                                  in_features // group_size).to(torch.int8)

    # Step 3: Pack the quantized weights to real quantized weights
    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    # Output rows are split as (M//32, 2, 2, 8) and input columns as
    # (K//32, 2, 4, 4); the permutes below bring them into kernel tile order.
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    # Move the size-2 "row + 16" axis last so it can be packed into nibbles.
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous().to(torch.int8))
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    # Each byte holds two output rows 16 apart within a 32-row tile: the
    # upper row in the high nibble, the lower row in the low nibble. The
    # shift is done in int8, so high nibbles >= 8 wrap negative; only the
    # resulting bit pattern matters.
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features,
                                                  in_features // 2)

    outputs.append(W_unpack_repacked)

    # for the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ... following the requirement of tensor core gemm
    # ---- Pack the scales ---- #
    outputs.append(s1_scales.reshape(out_features))

    # Transpose to (groups, out_features), then within every chunk of 32
    # output channels reorder as 0, 8, 16, 24, 1, 9, ... via a (4, 8) transpose.
    s2_scales = (s2_scales.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_scales = (s2_scales.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size,
                                  out_features).contiguous()
    outputs.append(s2_scales)

    # ---- Pack the zeros ---- #
    # Same (groups, out_features) transpose and 32-channel reorder as scales.
    s2_szeros = (s2_szeros.reshape(out_features, in_features //
                                   group_size).transpose(0, 1).contiguous())
    s2_szeros = s2_szeros.reshape(in_features // group_size, out_features // 32,
                                  32)
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features).contiguous())

    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # We convert the s2_zeros -> -s2_zeros * s2_scales
    # NOTE(review): s2_szeros is already int8 here, so the negation happens in
    # int8; a value of -128 would wrap. Presumably inputs stay in [-127, 127].
    s2_szeros = (-s2_szeros).int()  # It has been pre-scaled in DeepCompressor
    s2_szeros = s2_szeros.to(torch.int8)

    outputs.append(s2_szeros)

    return outputs


def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
                                    s1_scales: torch.FloatTensor,
                                    s1_szeros: torch.FloatTensor):
    """Pack 4-bit weights two-per-byte and flatten per-channel scales/zeros.

    Args:
        qweight: int8 tensor of shape ``(out_features, in_features)``;
            presumably 4-bit values in ``[0, 15]`` as produced by
            ``qserve_quantize_weight_per_channel`` (TODO confirm at callers).
            Both dimensions must be multiples of 32 for the reshapes below.
        s1_scales: Per-output-channel scales (``out_features`` values).
        s1_szeros: Per-output-channel zero points (``out_features`` values).

    Returns:
        ``[packed_qweight, s1_scales, s1_szeros]`` where ``packed_qweight`` is
        int8 of shape ``(out_features, in_features // 2)`` (tile-ordered) and
        the scales/zeros are float16 of shape ``(out_features,)``.
    """
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    # Same tile reorder as qserve_pack_reorder_per_group; no extra int8 cast
    # is needed because permute keeps the (asserted) int8 dtype.
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous())
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    # Each byte holds two output rows 16 apart within a 32-row tile: the
    # upper row in the high nibble, the lower row in the low nibble.
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[...,
                                                                           0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features, in_features //
                                                  2).contiguous()

    outputs.append(W_unpack_repacked)

    # ---- Pack the scales and zeros ---- #
    # Per-channel parameters need no reorder, only flattening to float16.
    s1_scales = s1_scales.reshape(out_features).contiguous()
    outputs.append(s1_scales.half())

    s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
    outputs.append(s1_szeros)

    return outputs


# TODO: Duplicates smooth_quantize and quantize_layers
def qserve_quantize(model, quant_config: QuantConfig):
    """Swap norms, MLPs and attention modules for QServe W4A8 variants.

    Modules whose leaf name is ``ln_f`` or ``ln_embed`` are left untouched.

    Returns:
        The root module (replaced if the root itself matched), with its
        ``quant_mode`` attribute set to ``quant_config.quant_mode``.
    """
    quant_mode = quant_config.quant_mode
    assert quant_config.quant_mode.is_qserve_w4a8()

    qserve_classes = {
        RmsNorm: QServeRmsNorm,
        # NOTE(review): LayerNorm is also mapped to QServeRmsNorm (no QServe
        # LayerNorm variant is imported in this file) - confirm intended.
        LayerNorm: QServeRmsNorm,
        GatedMLP: QServeGatedMLP,
        MLP: QServeMLP,
        Attention: QServeAttention,
    }
    skipped = ('ln_f', 'ln_embed')

    for name, layer, parent in model.named_modules_with_parent():
        attr = name.rsplit('.', 1)[-1]
        if attr in skipped:
            continue

        replacement = next((new for old, new in qserve_classes.items()
                            if isinstance(layer, old)), None)
        if replacement is None:
            continue

        kwargs = get_init_params(layer, replacement)
        kwargs["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            # The layer stores its per-rank head count; scale by tp_size.
            kwargs["num_attention_heads"] = (layer.num_attention_heads *
                                             layer.tp_size)

        quantized = replacement(**kwargs)
        if parent is None:
            model = quantized
        else:
            setattr(parent, attr, quantized)

    setattr(model, 'quant_mode', quant_mode)
    return model


def fp4_quantize(model, quant_config: QuantConfig):
    """Apply NVFP4 quantization to linears and (re-created) MoE modules.

    Returns:
        The quantized model produced by ``quantize_layers``.
    """
    assert quant_config.quant_mode.has_nvfp4()
    fp4_map = {
        ColumnLinear: FP4Linear,
        RowLinear: FP4RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }
    return quantize_layers(model, quant_config, fp4_map)


# For now, KV-cache quantization is applied to every attention layer in the model.
def kv_cache_quantize(model):
    """Attach KV-cache scaling-factor parameters to every attention module.

    Each ``Attention``, ``SmoothQuantAttention`` or ``Fp8RowwiseAttention``
    gets a float32 ``kv_cache_scaling_factor`` (used for dequant) and a
    float32 ``kv_cache_rcp_scaling_factor`` (used for quant), both of shape
    ``(1,)``.

    Returns:
        The same ``model`` object, modified in place.
    """
    attention_types = (Attention, SmoothQuantAttention, Fp8RowwiseAttention)
    for _, module in model.named_modules():
        if not isinstance(module, attention_types):
            continue
        module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                   dtype='float32')
        module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
    return model


def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    """Quantize ``model`` according to a global or per-layer quant config.

    For ``QuantAlgo.MIXED_PRECISION`` the quant mode is looked up per module
    name; otherwise one mode applies everywhere. Each module with a non-zero
    mode is passed to the matching quantization pass together with its
    per-layer config (``quant_config._get_quant_cfg(name)``). The result is
    written back into the parent. If the matched module is the root (no
    parent), the whole model has been processed and the loop stops.

    Args:
        model: Root module; must provide ``named_modules_with_parent()`` and
            ``config`` (used by the weight-only passes).
        quant_config: Global ``QuantConfig`` or per-layer ``LayerQuantConfig``.

    Returns:
        The quantized model, with KV-cache scales attached when the overall
        quant mode requests KV-cache quantization, and with its
        ``quant_mode`` attribute set to ``quant_config.quant_mode``.
    """
    for name, module, parent in model.named_modules_with_parent():

        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode
        if layer_quant_mode == QuantMode(0):
            continue

        layer_quant_cfg = quant_config._get_quant_cfg(name)

        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            # Use the per-layer config like every other branch. Under mixed
            # precision the top-level config does not describe this layer.
            module = qserve_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(module, layer_quant_cfg,
                                                        model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg,
                                              model.config)

        if parent is not None:  # for per layer
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # for all layer
            model = module
            break

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
